# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------

"""
DESCRIPTION:
    This sample demonstrates how to use agent operations with file searching from
    the Azure Agents service using an asynchronous client, streaming the run's
    events as they arrive.

USAGE:
    python sample_agents_stream_file_search_async.py

    Before running the sample:

    pip install azure-ai-projects azure-ai-agents azure-identity aiohttp

    Set these environment variables with your own values:
    1) PROJECT_ENDPOINT - The Azure AI Project endpoint, as found in the Overview
                          page of your Azure AI Foundry portal.
    2) MODEL_DEPLOYMENT_NAME - The deployment name of the AI model, as found under the "Name" column in
       the "Models + endpoints" tab in your Azure AI Foundry project.
"""

import asyncio
import os
from contextlib import AsyncExitStack
from typing import Any, Awaitable, Callable, Dict

from azure.ai.projects.aio import AIProjectClient
from azure.ai.agents.models import (
    AgentStreamEvent,
    FilePurpose,
    FileSearchTool,
    ListSortOrder,
    MessageDeltaChunk,
    MessageDeltaTextContent,
    MessageDeltaTextFileCitationAnnotation,
    MessageDeltaTextFileCitationAnnotationObject,
    RunAdditionalFieldList,
    RunStep,
    RunStepDeltaChunk,
    RunStepFileSearchToolCall,
    RunStepToolCallDetails,
    ThreadMessage,
    ThreadRun,
)
from azure.identity.aio import DefaultAzureCredential

# Local sample document that main() uploads and indexes for file search.
asset_file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../assets/product_info_1.md"))


async def _delete_resource(label: str, delete: Callable[..., Awaitable[Any]], *args: Any, **kwargs: Any) -> None:
    """Await ``delete(*args, **kwargs)`` and then report that *label* was deleted.

    Registered as an ``AsyncExitStack`` callback in :func:`main` so that each service
    resource is removed even when a later step of the sample raises.
    """
    await delete(*args, **kwargs)
    print(f"Deleted {label}")


def _citation_label(cited_file_id: str, uploaded_file_id: str, file_name: str) -> str:
    """Return *file_name* when the citation points at the uploaded file, otherwise the cited file ID."""
    return file_name if cited_file_id == uploaded_file_id else cited_file_id


def _handle_message_delta(
    chunk: "MessageDeltaChunk", references: Dict[str, str], uploaded_file_id: str, file_name: str
) -> None:
    """Print one streamed text delta, replacing known citation markers with file labels.

    :param chunk: The streamed message delta.
    :param references: Marker text -> citation label. Updated in place and kept across
        deltas so markers recorded earlier are still substituted in later deltas.
    :param uploaded_file_id: ID of the file uploaded by this sample.
    :param file_name: Label shown for citations that point at the uploaded file.
    """
    text = chunk.text
    content = chunk.delta.content
    if content and isinstance(content[0], MessageDeltaTextContent):
        text_content = content[0]
        if text_content.text and text_content.text.annotations:
            for annotation in text_content.text.annotations:
                if (
                    isinstance(annotation, MessageDeltaTextFileCitationAnnotation)
                    and isinstance(annotation.file_citation, MessageDeltaTextFileCitationAnnotationObject)
                    and annotation.file_citation.file_id
                ):
                    citation = _citation_label(annotation.file_citation.file_id, uploaded_file_id, file_name)
                    if annotation.text:
                        references[annotation.text] = citation
                    print(f"File citation delta received: [{citation}]")
    for marker, citation in references.items():
        text = text.replace(marker, f" [{citation}]")
    print(f"Text delta received: {text}")


def _print_file_search_results(step: "RunStep") -> None:
    """Print the text of every file-search result attached to a tool-call run step."""
    if not isinstance(step.step_details, RunStepToolCallDetails):
        return
    for tool_call in step.step_details.tool_calls:
        if not isinstance(tool_call, RunStepFileSearchToolCall) or not tool_call.file_search:
            continue
        # The sample indexes a single file, but print every result rather than assuming exactly one.
        for result in tool_call.file_search.results or []:
            if result.content and result.content[0].text:
                print(f"The search tool has found the next relevant content in the file {result.file_name}:")
                print(result.content[0].text)
                print("===============================================================")


async def _stream_run(
    agents_client: Any, thread_id: str, agent_id: str, uploaded_file_id: str, file_name: str
) -> None:
    """Stream a run of *agent_id* on *thread_id* and print each event as it arrives."""
    references: Dict[str, str] = {}
    # Request file-search contents so that _print_file_search_results has result text to print.
    async with await agents_client.runs.stream(
        thread_id=thread_id, agent_id=agent_id, include=[RunAdditionalFieldList.FILE_SEARCH_CONTENTS]
    ) as stream:
        async for event_type, event_data, _ in stream:
            if isinstance(event_data, MessageDeltaChunk):
                _handle_message_delta(event_data, references, uploaded_file_id, file_name)

            elif isinstance(event_data, RunStepDeltaChunk):
                print(f"RunStepDeltaChunk received. ID: {event_data.id}.")

            elif isinstance(event_data, ThreadMessage):
                print(f"ThreadMessage created. ID: {event_data.id}, Status: {event_data.status}")

            elif isinstance(event_data, ThreadRun):
                print(f"ThreadRun status: {event_data.status}")
                if event_data.status == "failed":
                    print(f"Run failed. Error: {event_data.last_error}")

            elif isinstance(event_data, RunStep):
                print(f"RunStep type: {event_data.type}, Status: {event_data.status}")
                _print_file_search_results(event_data)

            elif event_type == AgentStreamEvent.ERROR:
                print(f"An error occurred. Data: {event_data}")

            elif event_type == AgentStreamEvent.DONE:
                print("Stream completed.")

            else:
                print(f"Unhandled Event Type: {event_type}, Data: {event_data}")


async def _print_thread_messages(agents_client: Any, thread_id: str, uploaded_file_id: str, file_name: str) -> None:
    """Print the last text content of every message in the thread, oldest first, with citations resolved."""
    messages = agents_client.messages.list(thread_id=thread_id, order=ListSortOrder.ASCENDING)
    async for msg in messages:
        if not msg.text_messages:
            continue
        last_text_message = msg.text_messages[-1]
        last_text = last_text_message.text.value
        for annotation in last_text_message.text.annotations:
            # Only file-citation annotations carry a cited file; skip any other annotation kind.
            file_citation = getattr(annotation, "file_citation", None)
            if file_citation is None:
                continue
            citation = _citation_label(file_citation.file_id, uploaded_file_id, file_name)
            last_text = last_text.replace(annotation.text, f" [{citation}]")
        print(f"{msg.role}: {last_text}")


async def main() -> None:
    """Run the streaming file-search sample end to end.

    Uploads the sample asset, indexes it in a vector store, and creates an agent with a
    ``FileSearchTool`` over that store. It then streams one run and prints the resulting
    thread. Every service resource created here is deleted on exit, including when a step
    fails part-way through.
    """
    file_name = os.path.split(asset_file_path)[-1]

    # A client does not close the credential it was given, so the async credential
    # gets its own ``async with`` scope.
    async with DefaultAzureCredential() as credential:
        project_client = AIProjectClient(endpoint=os.environ["PROJECT_ENDPOINT"], credential=credential)
        # The exit stack is entered after the client, so its cleanup callbacks run while the client is open.
        async with project_client, AsyncExitStack() as cleanup:
            agents_client = project_client.agents

            # Upload file and create vector store; register each deletion as soon as the resource exists.
            file = await agents_client.files.upload_and_poll(file_path=asset_file_path, purpose=FilePurpose.AGENTS)
            print(f"Uploaded file, file ID: {file.id}")
            cleanup.push_async_callback(_delete_resource, "file", agents_client.files.delete, file_id=file.id)

            vector_store = await agents_client.vector_stores.create_and_poll(
                file_ids=[file.id], name="my_vectorstore"
            )
            print(f"Created vector store, vector store ID: {vector_store.id}")
            cleanup.push_async_callback(
                _delete_resource, "vector store", agents_client.vector_stores.delete, vector_store.id
            )

            # Create file search tool with resources followed by creating agent
            file_search = FileSearchTool(vector_store_ids=[vector_store.id])

            agent = await agents_client.create_agent(
                model=os.environ["MODEL_DEPLOYMENT_NAME"],
                name="my-agent",
                instructions="Hello, you are helpful agent and can search information from uploaded files",
                tools=file_search.definitions,
                tool_resources=file_search.resources,
            )
            print(f"Created agent, ID: {agent.id}")
            cleanup.push_async_callback(_delete_resource, "agent", agents_client.delete_agent, agent.id)

            # Create thread for communication
            thread = await agents_client.threads.create()
            print(f"Created thread, ID: {thread.id}")

            # Create message to thread
            message = await agents_client.messages.create(
                thread_id=thread.id, role="user", content="Hello, what Contoso products do you know?"
            )
            print(f"Created message, ID: {message.id}")

            await _stream_run(agents_client, thread.id, agent.id, file.id, file_name)

            # Print the conversation before the stack unwinds. The callbacks then run in
            # reverse registration order: agent, vector store, file.
            await _print_thread_messages(agents_client, thread.id, file.id, file_name)


if __name__ == "__main__":
    # Build the sample coroutine and drive it to completion on a fresh event loop.
    sample = main()
    asyncio.run(sample)
